/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Unit tests for epilogues
*/
/**
 * \file test/unit/epilogue/threadblock/bias_add_testbed.h
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once

#include <fstream>

#include "../../common/cutlass_unit_test.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/complex.h"
#include "cutlass/half.h"

#include "cutlass/epilogue/thread/linear_combination.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace test {
namespace kernel {

template <typename Epilogue>
__global__ void epilogue_threadblock(
        typename Epilogue::OutputTileIterator::Params params_D,
        typename Epilogue::OutputTileIterator::Element* ptr_D,
        typename Epilogue::OutputTileIterator::Params params_C,
        typename Epilogue::OutputTileIterator::Element* ptr_C,
        typename Epilogue::BiasTileIterator::Params params_bias,
        typename Epilogue::BiasTileIterator::Element* ptr_bias,
        typename Epilogue::OutputOp::Params params_output_op,
        typename Epilogue::OutputTileIterator::LogicalCoord extent,
        cutlass::TensorRef<typename Epilogue::WarpMmaOperator::ElementC,
                           typename Epilogue::WarpMmaOperator::LayoutC>
                accumulator_ref,
        int epilogue_count = 1) {
    __shared__ typename Epilogue::SharedStorage shared_storage;

    int thread_idx = threadIdx.x;
    int warp_idx = threadIdx.x / 32;
    int lane_idx = threadIdx.x % 32;

    //
    // Construct the epilogue
    //

    // Tile iterator writing to output tile
    typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, extent,
                                                     thread_idx);

    // Tile iterator writing to output tile
    typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, extent,
                                                     thread_idx);

    // Bias tile iterator to read bias
    typename Epilogue::BiasTileIterator iterator_bias(params_bias, ptr_bias,
                                                      extent, thread_idx);

    // Epilogue operator
    Epilogue epilogue(shared_storage, thread_idx, warp_idx, lane_idx);

    //
    // Initialize the accumulators
    //

    int warp_mn =
            warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN);
    int warp_m = warp_mn % Epilogue::WarpCount::kM;
    int warp_n = warp_mn / Epilogue::WarpCount::kM;

    accumulator_ref.add_coord_offset(
            {warp_m * Epilogue::WarpMmaOperator::Shape::kM,
             warp_n * Epilogue::WarpMmaOperator::Shape::kN});

    typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(
            accumulator_ref, lane_idx);

    typename Epilogue::AccumulatorTile accumulators;
    accumulators.clear();
    accumulator_iterator.load(accumulators);

#if 0
  // For debugging, enable this block of code to fill each accumulator element with its
  // source thread ID.
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < accumulators.size(); ++i) {
    typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x);
    //typename Epilogue::WarpMmaOperator::ElementC x(i);
    accumulators[i] = x;
  }

  /*
  #pragma unroll 1
  for (int tid = 0; tid < 32; ++tid) {
    if (tid == thread_idx) {
      printf("\nT%d: ", thread_idx);
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < accumulators.size(); ++i) {
        printf("%d ", int(accumulators[i]));
      }  
    }
  }

  if (thread_idx == 0) {
    printf("\n\n");  
  }
  */

  __syncthreads();

#endif

    //
    // Perform the epilogue operation
    //

    typename Epilogue::OutputOp output_op(params_output_op);

    // Place the epilogue in a loop
    for (int iter = 0; iter < epilogue_count; ++iter) {
        epilogue(output_op, iterator_D, accumulators, iterator_bias,
                 iterator_C);
    }
}

}  // namespace kernel
}  // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Host-side testbed that drives test::kernel::epilogue_threadblock and
/// verifies a bias-add epilogue element-by-element.
///
/// For each tested problem size and (alpha, beta, gamma) triple, the expected
/// output is computed on the host as
///     round(clamp(alpha * accum + beta * bias + gamma * source, -128, 127))
/// which assumes an integer-like ElementOutput (e.g. int8_t).
template <typename Epilogue_>
class EpilogueTestbed {
public:
    using Epilogue = Epilogue_;
    using ElementAccumulator = typename Epilogue::ElementAccumulator;
    using ElementCompute = typename Epilogue::OutputOp::ElementCompute;
    using ElementOutput = typename Epilogue::ElementOutput;
    using ElementBias = typename Epilogue::ElementBias;
    using OutputOpParams = typename Epilogue::OutputOp::Params;
    using LayoutOutput = typename Epilogue::OutputTileIterator::Layout;

public:
    //
    // Data members
    //

    /// Full (untruncated) logical problem size: (Shape::kM, Shape::kN).
    cutlass::MatrixCoord quantized_size;
    /// Accumulator input, viewed as a row-major M x N matrix.
    cutlass::HostTensor<ElementAccumulator, cutlass::layout::RowMajor>
            accumulator_tensor;
    /// Source operand C read by the epilogue.
    cutlass::HostTensor<ElementOutput, LayoutOutput> source_tensor;
    /// Destination tensor D written by the epilogue.
    cutlass::HostTensor<ElementOutput, LayoutOutput> output_tensor;
    /// Bias vector; run() resizes it to one element per extent.c().
    cutlass::HostTensor<ElementBias, LayoutOutput> bias_tensor;

public:
    //
    // Methods
    //

    // NOTE(review): these initial 4D extents are ordered (kM, 1, 1, kN),
    // while run_all() passes extents ordered (columns, 1, 1, rows). This is
    // harmless because run() reset()s every tensor with the extent it is
    // given, but the asymmetry is worth confirming upstream.
    EpilogueTestbed()
            : quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN),
              accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}),
              source_tensor({Epilogue::Shape::kM, 1, 1, Epilogue::Shape::kN}),
              output_tensor({Epilogue::Shape::kM, 1, 1, Epilogue::Shape::kN}),
              bias_tensor({1, 1, 1, Epilogue::Shape::kN}) {}

    /// Sweeps a handful of problem sizes and (alpha, beta, gamma) scale
    /// factors; stops and returns false at the first failing combination.
    bool run_all() {
        double alpha_values[] = {1.111318381, 0, 2.277113};
        double beta_values[] = {1.222138417, 1.11971, -1.18746114};
        double gamma_values[] = {0, 1.233137, -1.3287414};

        // Test runtime explodes if we tried to test every case exhaustively.
        // This tests the full output tile and several smaller sizes to stress
        // predication.
        for (int m_idx = 0; m_idx < 3; ++m_idx) {
            for (int n_idx = 0; n_idx < 3; ++n_idx) {
                // Shrink rows in steps of 4 and columns in sub-access-sized
                // steps so partially-guarded accesses are exercised.
                int m = quantized_size.row() - m_idx * 4;
                int n = quantized_size.column() -
                        n_idx * Epilogue::kElementsPerAccess / 4;

                for (double const& alpha : alpha_values) {
                    for (double const& beta : beta_values) {
                        for (double const& gamma : gamma_values) {
                            // Extent is packed as (n = columns, 1, 1,
                            // c = rows); run() reads it back the same way.
                            bool passed = run(
                                    {n, 1, 1, m},
                                    {cutlass::from_real<ElementCompute>(alpha),
                                     cutlass::from_real<ElementCompute>(beta),
                                     cutlass::from_real<ElementCompute>(
                                             gamma)});

                            if (!passed) {
                                return false;
                            }
                        }
                    }
                }
            }
        }

        return true;
    }

    /// Runs one test case.
    ///
    /// \param extent 4D coordinate packed as (n = columns, 1, 1, c = rows) of
    ///     the logical output matrix; reads below consistently use extent.n()
    ///     for columns and extent.c() for rows.
    /// \param output_params alpha/beta/gamma parameters of the output op.
    /// \returns true if every output element matches the host reference.
    bool run(cutlass::Tensor4DCoord extent, OutputOpParams output_params) {
        //
        // Initialize problem space
        //
        accumulator_tensor.reset({extent.c(), extent.n()});
        source_tensor.reset(extent);
        output_tensor.reset(extent);
        bias_tensor.reset({1, 1, 1, extent.c()});

        // Fixed seeds (offset per tensor) keep the test deterministic while
        // decorrelating the three inputs.
        uint64_t seed = 2019;

        cutlass::reference::host::TensorFillRandomUniform(
                accumulator_tensor.host_view(), seed, 20, -20, 0);

        cutlass::reference::host::TensorFillRandomUniform(
                source_tensor.host_view(), seed + 2018, 20, -20, 0);

        cutlass::reference::host::TensorFillRandomUniform(
                bias_tensor.host_view(), seed + 2020, 20, -20, 0);

        /// For debugging, zero the bias instead:
        /// cutlass::reference::host::TensorFill(
        ///     bias_tensor.host_view(), ElementBias(0));

        // Pre-fill the output with a fixed value (-127) — presumably a
        // sentinel so elements the kernel fails to write differ from the
        // reference.
        ElementOutput default_output = ElementOutput(-127);
        cutlass::reference::host::TensorFill(output_tensor.host_view(),
                                             default_output);

        accumulator_tensor.sync_device();
        output_tensor.sync_device();
        source_tensor.sync_device();
        bias_tensor.sync_device();

        //
        // Initialize epilogue parameters
        //

        typename Epilogue::OutputTileIterator::Params params_D(
                output_tensor.device_ref().layout());
        typename Epilogue::OutputTileIterator::Params params_C(
                source_tensor.device_ref().layout());
        typename Epilogue::BiasTileIterator::Params params_bias(
                bias_tensor.device_ref().layout());

        //
        // Launch kernel: a single threadblock with one 32-wide warp per
        // epilogue warp.
        //

        dim3 grid(1, 1);
        dim3 block(Epilogue::WarpCount::kCount * 32, 1);

        test::kernel::epilogue_threadblock<Epilogue><<<grid, block>>>(
                params_D, output_tensor.device_data(), params_C,
                source_tensor.device_data(), params_bias,
                bias_tensor.device_data(), output_params,
                {extent.c(), extent.n()}, accumulator_tensor.device_view());

        // Launch-configuration errors (e.g. an invalid block dimension) are
        // reported immediately via cudaGetLastError(); asynchronous execution
        // errors surface at the subsequent synchronization.
        cudaError_t result = cudaGetLastError();
        if (result == cudaSuccess) {
            result = cudaDeviceSynchronize();
        }

        if (result != cudaSuccess) {
            std::cerr << "Kernel error: " << cudaGetErrorString(result)
                      << std::endl;
            return false;
        }

        //
        // Verify results against a scalar host reference
        //
        output_tensor.sync_host();

        int errors = 0;
        int const kMaxErrors = 5;  // cap reported failures to limit log spam

        for (int r = 0; errors < kMaxErrors && r < extent.c(); ++r) {
            for (int c = 0; errors < kMaxErrors && c < extent.n(); ++c) {
                cutlass::Tensor4DCoord coord{c, 0, 0, r};
                ElementOutput got = output_tensor.at(coord);

                // Reference: alpha * accum + beta * bias + gamma * source,
                // saturated to [-128, 127], then rounded to nearest.
                ElementCompute intermediate =
                        output_params.alpha *
                                ElementCompute(accumulator_tensor.at(
                                        cutlass::MatrixCoord{r, c})) +
                        output_params.beta *
                                ElementCompute(bias_tensor.at(
                                        cutlass::Tensor4DCoord{0, 0, 0, r})) +
                        output_params.gamma *
                                ElementCompute(source_tensor.at(coord));
                intermediate = intermediate < -128.f ? -128.f : intermediate;
                intermediate = intermediate > 127.f ? 127.f : intermediate;
                ElementOutput expected =
                        ElementOutput(std::round(intermediate));

                if (expected != got) {
                    using OutputIO = cutlass::ScalarIO<ElementOutput>;

                    EXPECT_TRUE(false)
                            << "-------\n"
                            << "Error - output element (" << coord
                            << ") - expected: " << OutputIO(expected)
                            << ",  got: " << OutputIO(got) << std::endl;

                    ++errors;
                }
            }
        }

        //
        // On failure, dump the full output tile to a CSV file for offline
        // inspection.
        //

        if (errors) {
            std::stringstream ss;
            ss << "output_tensor_op_" << Epilogue::Shape::kM << "x"
               << Epilogue::Shape::kN << "_"
               << Epilogue::WarpTileIterator::WarpShape::kM << "x"
               << Epilogue::WarpTileIterator::WarpShape::kN << "_slice_"
               << Epilogue::WarpCount::kK << ".csv";

            std::ofstream output_file(ss.str());
            output_file << output_tensor.host_view();
        }

        return !errors;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
